import plotly.io as pio
import torch

from midi_rae.core import PatchState, HierarchicalPatchState, EncoderOutput

pio.renderers.default = 'notebook'
# -----------------------------------------------------------------------------
# Smoke test for make_emb_viz: build two EncoderOutput fixtures whose patch
# embeddings are correlated (view 2 = view 1 + small Gaussian noise), then
# render the embedding visualization without UMAP.
# NOTE(review): `make_emb_viz` is assumed to be defined earlier in this
# notebook/file — confirm if running this chunk standalone.
bs, dim = 32, 256   # batch size, embedding dimension
num_patch = 64      # 8x8 grid of patches

# Fake embeddings — the second view is a lightly perturbed copy of the first,
# so the two point clouds should nearly overlap in the viz.
z1_cls = torch.randn(bs, 1, dim)
z1_patch = torch.randn(bs, num_patch, dim)
z2_cls = z1_cls + 0.1 * torch.randn(bs, 1, dim)
z2_patch = z1_patch + 0.1 * torch.randn(bs, num_patch, dim)

# Positions: (-1, -1) sentinel for the CLS token, (row, col) for each patch.
# cartesian_prod yields the same (64, 2) int64 tensor, row-major, as the
# original stack-of-comprehension, without allocating 64 tiny tensors.
cls_pos = torch.tensor([[-1, -1]])
patch_pos = torch.cartesian_prod(torch.arange(8), torch.arange(8))

# MAE masks: everything visible (all True).
mae_mask_cls = torch.ones(1, dtype=torch.bool)
mae_mask_patch = torch.ones(num_patch, dtype=torch.bool)

# Non-empty masks: view 1 fully populated; view 2 has the second half of the
# batch marked empty to exercise the empty-patch code path.
ne1 = torch.ones(bs, num_patch, dtype=torch.bool)
ne2 = torch.ones(bs, num_patch, dtype=torch.bool)
ne2[16:, :] = False


def _fake_encoder_output(emb_cls, emb_patch, patch_non_empty):
    """Assemble an EncoderOutput from fake CLS/patch embeddings.

    Factors out the construction previously duplicated for the two views;
    only the embeddings and the patch non-empty mask differ between them.
    """
    cls_non_empty = torch.ones(bs, 1, dtype=torch.bool)
    return EncoderOutput(
        patches=HierarchicalPatchState(levels=[
            PatchState(emb=emb_cls, pos=cls_pos,
                       non_empty=cls_non_empty, mae_mask=mae_mask_cls),
            PatchState(emb=emb_patch, pos=patch_pos,
                       non_empty=patch_non_empty, mae_mask=mae_mask_patch),
        ]),
        full_pos=torch.cat([cls_pos, patch_pos]),
        full_non_empty=torch.cat([cls_non_empty, patch_non_empty], dim=1),
        mae_mask=torch.cat([mae_mask_cls, mae_mask_patch]),
    )


enc_out1 = _fake_encoder_output(z1_cls, z1_patch, ne1)
enc_out2 = _fake_encoder_output(z2_cls, z2_patch, ne2)

# Minimal batch metadata expected by the viz (file indices + 2-dim deltas).
batch = {'file_idx': torch.arange(bs), 'deltas': torch.randint(0, 12, (bs, 2))}
figs = make_emb_viz((enc_out1, enc_out2), title='testing', batch=batch,
                    do_umap=False, debug=True)